{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow 2.0 Tutorial: AutoGraph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A notable new feature of `tf.function` is AutoGraph, which lets you write graph code using natural Python syntax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.ops import control_flow_util\n",
    "control_flow_util.ENABLE_CONTROL_FLOW_V2 = True"
   ]
  },
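  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. The tf.function decorator\n",
    "When you annotate a function with `tf.function`, you can still call it like any other function.  \n",
    "It is compiled into a graph, which means it can execute faster, run on GPU or TPU, and be exported to SavedModel."
   ]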
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=25, shape=(3, 3), dtype=float32, numpy=\n",
       "array([[0.75023645, 0.19047515, 0.10737072],\n",
       "       [1.1521267 , 0.49491584, 0.19416495],\n",
       "       [0.5541876 , 0.24642248, 0.09543521]], dtype=float32)>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def simple_nn_layer(x, y):\n",
    "    return tf.nn.relu(tf.matmul(x, y))\n",
    "\n",
    "\n",
    "x = tf.random.uniform((3, 3))\n",
    "y = tf.random.uniform((3, 3))\n",
    "\n",
    "simple_nn_layer(x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we inspect the result of the annotation, we can see that it is a special callable that handles all interaction with the TensorFlow runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.eager.def_function.Function at 0x7ff5e164eb38>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simple_nn_layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your code uses multiple functions, you do not need to annotate all of them: any function called from an annotated function also runs in graph mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=39, shape=(3,), dtype=int32, numpy=array([3, 5, 7], dtype=int32)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def linear_layer(x):\n",
    "    return 2 * x + 1\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def deep_net(x):\n",
    "    return tf.nn.relu(linear_layer(x))\n",
    "\n",
    "\n",
    "deep_net(tf.constant((1, 2, 3)))"
   ]
  },
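  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The claim that helper functions inherit graph mode can be checked with a minimal sketch. It assumes that `tf.executing_eagerly()` reports `False` while a `tf.function` is being traced; the names `report_mode` and `traced_caller` are illustrative and not part of the original tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def report_mode(x):\n",
    "    # Plain Python helper with no decorator (illustrative name).\n",
    "    tf.print('executing eagerly:', tf.executing_eagerly())\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def traced_caller(x):\n",
    "    # The helper is traced as part of this function's graph.\n",
    "    return report_mode(x)\n",
    "\n",
    "\n",
    "report_mode(tf.constant(1))    # direct call: runs eagerly\n",
    "traced_caller(tf.constant(1))  # called from a tf.function: runs in graph mode"
   ]
  },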
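  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Using Python control flow\n",
    "When you use data-dependent control flow inside `tf.function`, you can write ordinary Python control flow statements and AutoGraph converts them into the appropriate TensorFlow ops. For example, an `if` statement whose condition depends on a `Tensor` is converted to `tf.cond()`."
   ]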
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square_if_positive(2) = 4\n",
      "square_if_positive(-2) = 0\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def square_if_positive(x):\n",
    "  if x > 0:\n",
    "    x = x * x\n",
    "  else:\n",
    "    x = 0\n",
    "  return x\n",
    "\n",
    "\n",
    "print('square_if_positive(2) = {}'.format(square_if_positive(tf.constant(2))))\n",
    "print('square_if_positive(-2) = {}'.format(square_if_positive(tf.constant(-2))))\n",
    "\n"
   ]
  },
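  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the same branch can be written by hand with `tf.cond`. This is only a sketch of roughly what the conversion amounts to, not AutoGraph's actual generated code; `square_if_positive_cond` is an illustrative name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def square_if_positive_cond(x):\n",
    "    # Both branches are callables that return a tensor of the same dtype and shape.\n",
    "    return tf.cond(x > 0, lambda: x * x, lambda: tf.zeros_like(x))\n",
    "\n",
    "\n",
    "print(square_if_positive_cond(tf.constant(2)))\n",
    "print(square_if_positive_cond(tf.constant(-2)))"
   ]
  },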
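  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "AutoGraph supports common Python statements such as `while`, `for`, `if`, `break`, `continue`, and `return`, including nested ones. This means you can use Tensor expressions in the conditions of `while` and `if` statements, or iterate over a Tensor in a `for` loop."
   ]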
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=149, shape=(), dtype=int32, numpy=42>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def sum_even(items):\n",
    "  s = 0\n",
    "  for c in items:\n",
    "    if c % 2 > 0:\n",
    "      continue\n",
    "    s += c\n",
    "  return s\n",
    "\n",
    "\n",
    "sum_even(tf.constant([10, 12, 15, 20]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "AutoGraph also provides a low-level API for advanced users. For example, we can use it to inspect the generated code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "from __future__ import print_function\n",
      "\n",
      "def tf__sum_even(items):\n",
      "  do_return = False\n",
      "  retval_ = None\n",
      "  s = 0\n",
      "\n",
      "  def loop_body(loop_vars, s_2):\n",
      "    c = loop_vars\n",
      "    continue_ = False\n",
      "    cond = c % 2 > 0\n",
      "\n",
      "    def if_true():\n",
      "      continue_ = True\n",
      "      return continue_\n",
      "\n",
      "    def if_false():\n",
      "      return continue_\n",
      "    continue_ = ag__.if_stmt(cond, if_true, if_false)\n",
      "    cond_1 = ag__.not_(continue_)\n",
      "\n",
      "    def if_true_1():\n",
      "      s_1, = s_2,\n",
      "      s_1 += c\n",
      "      return s_1\n",
      "\n",
      "    def if_false_1():\n",
      "      return s_2\n",
      "    s_2 = ag__.if_stmt(cond_1, if_true_1, if_false_1)\n",
      "    return s_2,\n",
      "  s, = ag__.for_stmt(items, None, loop_body, (s,))\n",
      "  do_return = True\n",
      "  retval_ = s\n",
      "  return retval_\n",
      "\n",
      "\n",
      "\n",
      "tf__sum_even.autograph_info__ = {}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(tf.autograph.to_code(sum_even.python_function, experimental_optional_features=None))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A more complex control flow example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fizz\n",
      "1\n",
      "2\n",
      "Fizz\n",
      "4\n",
      "Buzz\n",
      "Fizz\n",
      "7\n",
      "8\n",
      "Fizz\n",
      "Buzz\n",
      "11\n",
      "Fizz\n",
      "13\n",
      "14\n",
      "\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def fizzbuzz(n):\n",
    "  msg = tf.constant('')\n",
    "  for i in tf.range(n):\n",
    "    if tf.equal(i % 3, 0):\n",
    "      msg += 'Fizz'\n",
    "    elif tf.equal(i % 5, 0):\n",
    "      msg += 'Buzz'\n",
    "    else:\n",
    "      msg += tf.as_string(i)\n",
    "    msg += '\\n'\n",
    "  return msg\n",
    "\n",
    "\n",
    "print(fizzbuzz(tf.constant(15)).numpy().decode())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Keras and AutoGraph\n",
    "`tf.function` can also be used with object methods. For example, you can decorate a custom Keras model by annotating its `call` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=281, shape=(2,), dtype=int32, numpy=array([-1, -2], dtype=int32)>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CustomModel(tf.keras.models.Model):\n",
    "\n",
    "    @tf.function\n",
    "    def call(self, input_data):\n",
    "        if tf.reduce_mean(input_data) > 0:\n",
    "            return input_data\n",
    "        else:\n",
    "            return input_data // 2\n",
    "\n",
    "\n",
    "model = CustomModel()\n",
    "\n",
    "model(tf.constant([-2, -4]))"
   ]
  },
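  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Side effects\n",
    "Just as in eager mode, operations with side effects, such as `tf.Variable.assign` or `tf.print`, can be used normally inside `tf.function`. It inserts the necessary control dependencies to ensure they execute in order."
   ]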
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=7>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v = tf.Variable(5)\n",
    "\n",
    "@tf.function\n",
    "def find_next_odd():\n",
    "  v.assign(v + 1)\n",
    "  if tf.equal(v % 2, 0):\n",
    "    v.assign(v + 1)\n",
    "\n",
    "\n",
    "find_next_odd()\n",
    "v"
   ]
  },
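  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ordering guarantee can be seen in a small sketch that interleaves a variable update with `tf.print` inside a loop. The variable `counter` and the function `count_to_three` are illustrative names, not part of the original tutorial; the printed values are expected to appear in increasing order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "counter = tf.Variable(0)  # illustrative variable, separate from `v` above\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def count_to_three():\n",
    "    for _ in tf.range(3):\n",
    "        counter.assign_add(1)\n",
    "        # This read of `counter` is ordered after the assignment above.\n",
    "        tf.print('counter is now', counter)\n",
    "\n",
    "\n",
    "count_to_three()"
   ]
  },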
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Training a simple model with AutoGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 10 : loss 1.85892391 ; accuracy 0.37\n",
      "Step 20 : loss 1.25817835 ; accuracy 0.5125\n",
      "Step 30 : loss 0.871798933 ; accuracy 0.605666637\n",
      "Step 40 : loss 0.669722676 ; accuracy 0.66\n",
      "Step 50 : loss 0.413253099 ; accuracy 0.6976\n",
      "Step 60 : loss 0.340167761 ; accuracy 0.728\n",
      "Step 70 : loss 0.438654482 ; accuracy 0.748428583\n",
      "Step 80 : loss 0.346158206 ; accuracy 0.766875\n",
      "Step 90 : loss 0.241747677 ; accuracy 0.780555546\n",
      "Step 100 : loss 0.229299903 ; accuracy 0.7935\n",
      "Step 110 : loss 0.300931275 ; accuracy 0.803727269\n",
      "Step 120 : loss 0.369899929 ; accuracy 0.812583327\n",
      "Step 130 : loss 0.305111647 ; accuracy 0.82\n",
      "Step 140 : loss 0.396656752 ; accuracy 0.825857162\n",
      "Step 150 : loss 0.308267832 ; accuracy 0.831266642\n",
      "Step 160 : loss 0.323994666 ; accuracy 0.836312473\n",
      "Step 170 : loss 0.264144391 ; accuracy 0.840941191\n",
      "Step 180 : loss 0.450227708 ; accuracy 0.844333351\n",
      "Step 190 : loss 0.213473886 ; accuracy 0.848105252\n",
      "Step 200 : loss 0.224886 ; accuracy 0.85145\n",
      "Final step tf.Tensor(200, shape=(), dtype=int32) : loss tf.Tensor(0.224886, shape=(), dtype=float32) ; accuracy tf.Tensor(0.85145, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "def prepare_mnist_features_and_labels(x, y):\n",
    "  x = tf.cast(x, tf.float32) / 255.0\n",
    "  y = tf.cast(y, tf.int64)\n",
    "  return x, y\n",
    "\n",
    "def mnist_dataset():\n",
    "  (x, y), _ = tf.keras.datasets.mnist.load_data()\n",
    "  ds = tf.data.Dataset.from_tensor_slices((x, y))\n",
    "  ds = ds.map(prepare_mnist_features_and_labels)\n",
    "  ds = ds.take(20000).shuffle(20000).batch(100)\n",
    "  return ds\n",
    "\n",
    "train_dataset = mnist_dataset()\n",
    "model = tf.keras.Sequential((\n",
    "    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),\n",
    "    tf.keras.layers.Dense(100, activation='relu'),\n",
    "    tf.keras.layers.Dense(100, activation='relu'),\n",
    "    tf.keras.layers.Dense(10)))\n",
    "model.build()\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "\n",
    "compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "\n",
    "\n",
    "def train_one_step(model, optimizer, x, y):\n",
    "  with tf.GradientTape() as tape:\n",
    "    logits = model(x)\n",
    "    loss = compute_loss(y, logits)\n",
    "\n",
    "  grads = tape.gradient(loss, model.trainable_variables)\n",
    "  optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "  compute_accuracy(y, logits)\n",
    "  return loss\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def train(model, optimizer):\n",
    "  train_ds = mnist_dataset()\n",
    "  step = 0\n",
    "  loss = 0.0\n",
    "  for x, y in train_ds:\n",
    "    step += 1\n",
    "    loss = train_one_step(model, optimizer, x, y)\n",
    "    if tf.equal(step % 10, 0):\n",
    "      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())\n",
    "  # Return the running accuracy tracked by the metric object.\n",
    "  return step, loss, compute_accuracy.result()\n",
    "\n",
    "step, loss, accuracy = train(model, optimizer)\n",
    "print('Final step', step, ': loss', loss, '; accuracy', accuracy)"
   ]
  },
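  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. A note on batching\n",
    "In real applications, batching is essential for performance. The code best suited to AutoGraph conversion is code where control flow is decided at the batch level. If decisions are made at the level of individual examples, try to use batch APIs to maintain performance."
   ]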
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-5, -4, -3, -2, -1, 0, 1, 4, 9, 16]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def square_if_positive(x):\n",
    "  return [i ** 2 if i > 0 else i for i in x]\n",
    "\n",
    "\n",
    "square_if_positive(range(-5, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=1544, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A naive TensorFlow version of the code above, processing one element at a time\n",
    "@tf.function\n",
    "def square_if_positive_naive(x):\n",
    "  result = tf.TensorArray(tf.int32, size=x.shape[0])\n",
    "  for i in tf.range(x.shape[0]):\n",
    "    if x[i] > 0:\n",
    "      result = result.write(i, x[i] ** 2)\n",
    "    else:\n",
    "      result = result.write(i, x[i])\n",
    "  return result.stack()\n",
    "\n",
    "\n",
    "square_if_positive_naive(tf.range(-5, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=1554, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The same logic can also be written in vectorized form\n",
    "def square_if_positive_vectorized(x):\n",
    "  return tf.where(x > 0, x ** 2, x)\n",
    "\n",
    "\n",
    "square_if_positive_vectorized(tf.range(-5, 5))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
